import numpy as np
import numpy.linalg as LA
from sympy import re, im, I, E, Symbol, sqrt
import argparse


def g_CauchyK_num(S):
    """Build a smoothed symbolic transform of the singular values ``S``.

    Returns the SymPy expression in ``Symbol('z')``::

        (1 / (2M)) * sum_j [ 1 / (z + s_j - i*eta) + 1 / (z - s_j - i*eta) ]

    with ``M = len(S)`` and ``eta = sqrt(1 / (2M))``. Each singular value
    contributes both ``+s_j`` and ``-s_j``, i.e. the sum runs over the
    symmetrised set ``{±s_j}``. (Presumably this is the Cauchy/Stieltjes
    transform of the symmetrised singular-value distribution evaluated with
    an imaginary offset ``eta`` — confirm against the derivation.)

    Raises
    ------
    ValueError
        If ``S`` is empty, since the normalisation ``1 / (2M)`` is undefined.
    """
    # Local import: the module header only pulls in a handful of sympy names.
    from sympy import Add

    M = len(S)
    if M == 0:
        raise ValueError("g_CauchyK_num needs at least one singular value")

    z = Symbol('z')
    # The imaginary offset depends only on M, so compute it once.
    shift = I * np.sqrt(1 / (2 * M))

    terms = []
    for s in S:
        terms.append(1 / (z + s - shift))
        terms.append(1 / (z - s - shift))

    # Build the sum in a single call. Repeated ``ret += term`` re-flattens
    # the growing Add each time, which makes the loop quadratic in M.
    return Add(*terms) / (2 * M)

def _initial_dfr(SNR):
    """Return the starting smoothing factor ``dfr`` used by ``Estimator``.

    The factor grows with the signal-to-noise ratio: 32 for SNR <= 2,
    64 for 2 < SNR <= 4, and 128 for SNR > 4.
    """
    # Test the larger threshold first. With ``if SNR > 2 ... elif SNR > 4``
    # the 128 branch could never be reached.
    if SNR > 4:
        return 128
    if SNR > 2:
        return 64
    return 32


def _rie_terms(zz, gX, gS, alpha, SNR, z):
    """Evaluate the quantities shared by both parts of ``Estimator`` at ``zz``.

    Returns ``(gS_eval, zeta, Est)`` where ``gS_eval`` is ``gS`` evaluated at
    ``zz``, ``zeta = gS_eval / alpha`` and
    ``Est = gX(sqrt(Z)) + gX(-sqrt(Z))`` with ``Z = (zz / zeta - 1) / SNR``.
    ``gX`` and ``gS`` are SymPy expressions in the symbol ``z``.
    """
    gS_eval = gS.subs(z, zz).evalf()
    zeta = (1 / alpha) * gS_eval
    Z = (zz / zeta - 1) / SNR
    Est = gX.subs(z, sqrt(Z)).evalf() + gX.subs(z, -sqrt(Z)).evalf()
    return gS_eval, zeta, Est


def Estimator(S_s, gX, gS, SNR, alpha, N, e_oracle, verbose=True):
    """Compute the RIE eigenvalue estimates for X from the observed spectrum.

    Parameters
    ----------
    S_s : sequence of float
        Singular values of the observation. Their count ``M`` is the number
        of entries estimated individually.
    gX : sympy expression in ``Symbol('z')``
        Transform associated with the prior of X (presumably its
        Stieltjes/Cauchy transform — confirm against the derivation).
    gS : sympy expression in ``Symbol('z')``
        Transform built from the observed singular values, e.g. by
        ``g_CauchyK_num``.
    SNR : float
        Signal-to-noise ratio.
    alpha : float
        Aspect-ratio parameter (``a`` in ``main``, where ``M = int(N / a)``).
    N : int
        Length of the returned array.
    e_oracle : sequence of float
        Oracle values. They are only printed next to the estimates.
    verbose : bool, optional
        When True (the default, matching the historic behaviour), print every
        estimate together with its oracle value.

    Returns
    -------
    numpy.ndarray
        Length-``N`` array. Entries ``0..M-1`` are estimated one per singular
        value. Entries ``M..N-1`` all receive a single shared bulk value.
    """
    M = len(S_s)
    output_X = np.zeros(N)
    z = Symbol('z')

    # Part 1: one estimate per singular value, each evaluated slightly below
    # the real axis at S_s[i].
    dfr = _initial_dfr(SNR)
    for i in range(M):
        # From index 500 on, shrink the offset by 2**-0.8 every 25 steps.
        if i >= 500 and i % 25 == 0:
            dfr = dfr * (2 ** (-0.8))
        zz = S_s[i] - I * np.sqrt(dfr / (2 * M))
        gS_eval, zeta, Est = _rie_terms(zz, gX, gS, alpha, SNR, z)

        output_X[i] = alpha * im(((Est / zeta) / (2 * SNR * im(gS_eval))).evalf())

        if verbose:
            print(i)
            print(output_X[i])
            print(e_oracle[i])
            print('\n')

    # Part 2: the remaining N - M entries share one value, evaluated at a
    # single point x - i*eps. Skip it when there are no such entries. This
    # also avoids dividing by (alpha - 1) == 0 when M == N.
    if M < N:
        dfr = _initial_dfr(SNR)
        eps = np.sqrt(dfr / (2 * M))
        x = np.sqrt(dfr / (2 * M))
        zz = x - I * eps
        _, zeta, Est = _rie_terms(zz, gX, gS, alpha, SNR, z)

        output_X[M:] = ((alpha / (alpha - 1)) * (1 / 2) * im(Est / zeta)
                        * ((x ** 2 + eps ** 2) / eps)) / SNR

        if verbose:
            for i in range(M, N):
                print(output_X[i])
                print(e_oracle[i])
                print('\n')

    return output_X


def _sample_prior(prior, N, z):
    """Draw the N x N signal matrix X for ``prior`` and return ``(X, gX)``.

    ``gX`` is the SymPy expression in ``z`` that ``Estimator`` uses for this
    prior (presumably the Stieltjes transform of its limiting spectrum —
    confirm against the derivation).

    Raises
    ------
    ValueError
        If ``prior`` is neither 'Wigner' nor 'Wishart' (X would otherwise be
        undefined).
    """
    if prior == 'Wigner':
        # Symmetric matrix with off-diagonal entries of variance 1/N and
        # diagonal entries of variance 2/N, shifted by 3 * identity.
        X = np.triu(np.random.normal(0, 1, (N, N)), 1)
        X = X + np.transpose(X) + np.diag(np.random.normal(loc=0, scale=np.sqrt(2), size=(N)))
        X = X / np.sqrt(N)
        X = X + 3 * np.eye(N)
        gX = (z - 3 - sqrt(z - 5) * sqrt(z - 1)) / 2
    elif prior == 'Wishart':
        # X = G G^T / N with G a standard Gaussian matrix of shape (N, 4N).
        G = np.random.randn(N, 4 * N)
        X = G @ np.transpose(G) / N
        gX = (z - 3 - sqrt(z - 1) * sqrt(z - 9)) / (2 * z)
    else:
        raise ValueError(f"unknown prior {prior!r}; expected 'Wigner' or 'Wishart'")
    return X, gX


def _oracle_eigenvalues(U, X):
    """Return ``u_k^T X u_k`` for every column ``u_k`` of ``U``."""
    # One matrix product instead of a Python loop of N matrix-vector products.
    return np.einsum('ik,ik->k', U, X @ U)


def _reconstruct(U, eigenvalues):
    """Return ``U @ diag(eigenvalues) @ U.T`` without forming the diagonal matrix."""
    return (U * eigenvalues) @ np.transpose(U)


def main():
    """Compare the RIE of X with the oracle estimator over repeated trials.

    Command-line options:

    - ``-a``: alpha, where ``M = int(N / alpha)``.
    - ``-s``: the SNR.
    - ``-p``: the prior, 'Wigner' or 'Wishart'.

    For each of ``Ex`` trials the observation ``S = sqrt(SNR) * X @ Y + W``
    is formed. The relative squared Frobenius errors of the oracle and RIE
    reconstructions are appended to ``start.txt``. The per-trial errors are
    finally saved to ``X-<prior>_SNR=<SNR>_Oracle.npy`` and
    ``X-<prior>_SNR=<SNR>_RIE.npy``.
    """
    z = Symbol('z')
    p = argparse.ArgumentParser()

    # All options are mandatory. Without them the script fails later with an
    # obscure TypeError (missing -a/-s) or NameError (unknown prior).
    p.add_argument('-a', type=float, required=True, help='alpha; M = int(N / alpha)')
    p.add_argument('-s', type=float, required=True, help='signal-to-noise ratio')
    p.add_argument('-p', type=str, required=True, choices=('Wigner', 'Wishart'),
                   help='prior of X')

    args = p.parse_args()

    a = args.a
    prior = args.p
    SNR = args.s

    N = 2000
    M = int(N / a)

    Ex = 10  # number of independent trials

    E_X_oracle = np.zeros(Ex)
    E_X_RIE = np.zeros(Ex)

    for i in range(Ex):
        X, gX = _sample_prior(prior, N, z)

        # Noise
        Y = np.random.randn(N, M) / np.sqrt(N)
        W = np.random.randn(N, M) / np.sqrt(N)

        # Observation
        S = np.sqrt(SNR) * X @ Y + W

        # SVD (the right singular vectors are not used)
        U_s, S_s, _ = LA.svd(S)

        gS = g_CauchyK_num(S_s)

        X_norm = LA.norm(X) ** 2

        # Oracle estimator for X
        e_hat_X_oracle = _oracle_eigenvalues(U_s, X)
        X_hat_oracle = _reconstruct(U_s, e_hat_X_oracle)
        E_X_oracle[i] = LA.norm(X - X_hat_oracle) ** 2 / X_norm

        # RIE for X
        e_hat_X = Estimator(S_s, gX, gS, SNR, a, N, e_hat_X_oracle)
        X_hat = _reconstruct(U_s, e_hat_X)
        E_X_RIE[i] = LA.norm(X - X_hat) ** 2 / X_norm

        # Diagnostic: oracle error minus RIE error against the true spectrum,
        # for the first M entries and for the rest. LA.svd orders U_s's
        # columns (hence both estimates) by decreasing singular value, while
        # eigvalsh returns ascending eigenvalues. Reverse so both are in
        # decreasing order before comparing entrywise.
        Ev_x = LA.eigvalsh(X)[::-1]
        print(LA.norm(Ev_x[:M] - e_hat_X_oracle[:M]) ** 2 / X_norm
              - LA.norm(Ev_x[:M] - e_hat_X[:M]) ** 2 / X_norm)
        print(LA.norm(Ev_x[M:] - e_hat_X_oracle[M:]) ** 2 / X_norm
              - LA.norm(Ev_x[M:] - e_hat_X[M:]) ** 2 / X_norm)

        # Separate the fields so the record can be parsed back.
        with open('start.txt', 'a') as f:
            f.write(f"{prior} {SNR} {E_X_oracle[i]}\n{E_X_RIE[i]}\n")

    filename = 'X-' + prior + '_SNR=' + str(SNR) + '_Oracle.npy'
    np.save(filename, E_X_oracle)

    filename = 'X-' + prior + '_SNR=' + str(SNR) + '_RIE.npy'
    np.save(filename, E_X_RIE)
    

    

# Script entry point: run the experiment only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
    
